/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkExpectationBasedPointSetToPointSetMetricv4_hxx
#define itkExpectationBasedPointSetToPointSetMetricv4_hxx

#include "itkArray.h"
#include "itkCompensatedSummation.h"

namespace itk
{

/** Default constructor.
 *  The Gaussian width (sigma) starts at 1. The derived quantities m_PreFactor
 *  and m_Denominator start at zero; Initialize() computes them from the sigma
 *  before any evaluation. */
template <typename TFixedPointSet, typename TMovingPointSet, class TInternalComputationValueType>
ExpectationBasedPointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TInternalComputationValueType>::
  ExpectationBasedPointSetToPointSetMetricv4()
{
  this->m_PointSetSigma = 1.0;
  this->m_PreFactor = 0.0;
  this->m_Denominator = 0.0;
}

/** Prepare the metric for evaluation.
 *  After the superclass has initialized, validate the Gaussian width and cache
 *  the two quantities that every per-point evaluation reuses:
 *    m_PreFactor   = 1 / (sqrt(2 * pi) * sigma)
 *    m_Denominator = 2 * sigma^2
 *  Throws an itk::ExceptionObject when sigma is not larger than the epsilon of
 *  CoordinateType. */
template <typename TFixedPointSet, typename TMovingPointSet, class TInternalComputationValueType>
void
ExpectationBasedPointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TInternalComputationValueType>::Initialize()
{
  Superclass::Initialize();

  const auto sigma = this->m_PointSetSigma;

  // Written as "<=" (not "!(>)") so that the behavior for a NaN sigma is unchanged.
  const bool sigmaTooSmall = sigma <= NumericTraits<CoordinateType>::epsilon();
  if (sigmaTooSmall)
  {
    itkExceptionStringMacro("m_PointSetSigma is too small. <= epsilon");
  }

  // Exponent denominator of the Gaussian kernel exp(-d^2 / (2 sigma^2)).
  this->m_Denominator = 2.0 * itk::Math::sqr(sigma);
  // Normalization factor of the Gaussian kernel.
  this->m_PreFactor = 1.0 / (std::sqrt(2 * itk::Math::pi) * sigma);
}

/** Local metric value at a fixed-space point.
 *  Sums, over the m_EvaluationKNeighborhood closest transformed moving points,
 *  the negated Gaussian response
 *    -m_PreFactor * exp(-||point - neighbor||^2 / m_Denominator).
 *  The result is therefore <= 0; it is more negative the closer (and the more
 *  numerous) the nearby moving points are. The pixel argument is unused. */
template <typename TFixedPointSet, typename TMovingPointSet, class TInternalComputationValueType>
typename ExpectationBasedPointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TInternalComputationValueType>::
  MeasureType
  ExpectationBasedPointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TInternalComputationValueType>::
    GetLocalNeighborhoodValue(const PointType & point, const PixelType & itkNotUsed(pixel)) const
{
  // Identify the moving points nearest to the query point.
  NeighborsIdentifierType closestIds;
  this->m_MovingTransformedPointsLocator->FindClosestNPoints(point, this->m_EvaluationKNeighborhood, closestIds);

  // Compensated summation limits round-off when accumulating many small terms.
  CompensatedSummation<MeasureType> total;
  for (const auto id : closestIds)
  {
    const PointType   neighbor = this->m_MovingTransformedPointSet->GetPoint(id);
    const MeasureType squaredDistance = point.SquaredEuclideanDistanceTo(neighbor);
    total -= this->m_PreFactor * std::exp(-squaredDistance / this->m_Denominator);
  }

  return total.GetSum();
}

/** Local metric value and derivative at a fixed-space point.
 *
 *  measure receives the same quantity as GetLocalNeighborhoodValue(): the sum
 *  of the negated Gaussian responses of the closest transformed moving points.
 *
 *  localDerivative receives a force vector pointing from point toward the
 *  response-weighted mean of those neighbors. Its magnitude is scaled by the
 *  Gaussian response at that mean, relative to -measure. When |measure| is not
 *  above the MeasureType epsilon, the derivative is left at zero.
 *  The pixel argument is unused. */
template <typename TFixedPointSet, typename TMovingPointSet, class TInternalComputationValueType>
void
ExpectationBasedPointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TInternalComputationValueType>::
  GetLocalNeighborhoodValueAndDerivative(const PointType &     point,
                                         MeasureType &         measure,
                                         LocalDerivativeType & localDerivative,
                                         const PixelType &     itkNotUsed(pixel)) const
{
  localDerivative.Fill(0.0);

  NeighborsIdentifierType closestIds;
  this->m_MovingTransformedPointsLocator->FindClosestNPoints(point, this->m_EvaluationKNeighborhood, closestIds);

  // Per-neighbor (non-positive) Gaussian responses. They are kept so the
  // second pass can reuse them as mixing weights.
  Array<MeasureType> responses;
  responses.SetSize(this->m_EvaluationKNeighborhood);
  responses.Fill(0.0);

  CompensatedSummation<MeasureType> responseSum;
  for (unsigned int i = 0; i < closestIds.size(); ++i)
  {
    const PointType   neighbor = this->m_MovingTransformedPointSet->GetPoint(closestIds[i]);
    const MeasureType squaredDistance = point.SquaredEuclideanDistanceTo(neighbor);
    responses[i] = -this->m_PreFactor * std::exp(-squaredDistance / this->m_Denominator);
    responseSum += responses[i];
  }

  measure = responseSum.GetSum();

  // A (numerically) zero total response gives no usable weighted mean;
  // keep the zeroed derivative.
  if (itk::Math::abs(measure) <= NumericTraits<MeasureType>::epsilon())
  {
    return;
  }

  // Weighted mean of the neighbor positions. Each weight response/measure is a
  // ratio of two non-positive numbers, and the weights sum to one.
  PointType weightedMean{};
  for (unsigned int i = 0; i < closestIds.size(); ++i)
  {
    const PointType  neighbor = this->m_MovingTransformedPointSet->GetPoint(closestIds[i]);
    const VectorType neighborOffset = neighbor.GetVectorFromOrigin();
    weightedMean += (neighborOffset * responses[i] / measure);
  }

  const MeasureType meanSquaredDistance = point.SquaredEuclideanDistanceTo(weightedMean);
  const MeasureType gain = this->m_PreFactor * std::exp(-meanSquaredDistance / this->m_Denominator) / -measure;
  const VectorType  force = (weightedMean - point) * gain;

  for (unsigned int d = 0; d < localDerivative.Size(); ++d)
  {
    localDerivative[d] = force[d];
  }
}

/** Create a new instance of this metric via Self::New().
 *  The new instance receives this instance's moving point set, fixed point
 *  set, point-set sigma and evaluation K-neighborhood. Those four settings are
 *  the only ones this method transfers. */
template <typename TFixedPointSet, typename TMovingPointSet, class TInternalComputationValueType>
typename LightObject::Pointer
ExpectationBasedPointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TInternalComputationValueType>::
  InternalClone() const
{
  typename Self::Pointer clone = Self::New();

  clone->SetMovingPointSet(this->m_MovingPointSet);
  clone->SetFixedPointSet(this->m_FixedPointSet);
  clone->SetPointSetSigma(this->m_PointSetSigma);
  clone->SetEvaluationKNeighborhood(this->m_EvaluationKNeighborhood);

  return clone.GetPointer();
}

/** Print the state of this metric after the superclass state.
 *  Prints the Gaussian sigma, the cached Gaussian normalization factor and
 *  exponent denominator (both computed in Initialize()), and the number of
 *  neighbors used per evaluation. Labels match the member/accessor names.
 *  Real values go through NumericTraits<>::PrintType, following the ITK
 *  PrintSelf convention. */
template <typename TFixedPointSet, typename TMovingPointSet, class TInternalComputationValueType>
void
ExpectationBasedPointSetToPointSetMetricv4<TFixedPointSet, TMovingPointSet, TInternalComputationValueType>::PrintSelf(
  std::ostream & os,
  Indent         indent) const
{
  Superclass::PrintSelf(os, indent);

  os << indent << "PointSetSigma: "
     << static_cast<typename NumericTraits<CoordinateType>::PrintType>(this->m_PointSetSigma) << std::endl;
  os << indent << "PreFactor: " << static_cast<typename NumericTraits<MeasureType>::PrintType>(this->m_PreFactor)
     << std::endl;
  os << indent << "Denominator: " << static_cast<typename NumericTraits<MeasureType>::PrintType>(this->m_Denominator)
     << std::endl;
  // Label previously read "EvaluateKNeighborhood", which does not match the
  // member/setter name EvaluationKNeighborhood.
  os << indent << "EvaluationKNeighborhood: " << this->m_EvaluationKNeighborhood << std::endl;
}
} // end namespace itk


#endif
